{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.4.0+cu118\n",
      "0.19.0+cu118\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import models\n",
    "from torchvision.models.feature_extraction import create_feature_extractor\n",
    "\n",
    "# Record the library versions this notebook was executed with, one per line.\n",
    "print(torch.__version__, torchvision.__version__, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Load ResNet-18 with torchvision's default pretrained weights and dump its module tree.\n",
    "model = models.resnet18(\n",
    "    weights=models.ResNet18_Weights.DEFAULT\n",
    ")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the ResNet so that a forward pass returns a dict holding the activation of the\n",
    "# node `layer4.1.bn2` (the second BatchNorm of the last residual block) under the key `0`.\n",
    "# `create_feature_extractor` is imported in the notebook's top import cell.\n",
    "# NOTE(review): this node appears to sit before the block's skip-connection add and final\n",
    "# ReLU; if the full block output is what is wanted, `layer4` is the node to request - confirm.\n",
    "backbone = create_feature_extractor(model, return_nodes={\"layer4.1.bn2\": \"0\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 512, 7, 7])\n"
     ]
    }
   ],
   "source": [
    "# Dummy batch: a single 3-channel 224x224 input with values drawn uniformly from [0, 1).\n",
    "x = torch.rand(size=(1, 3, 224, 224), dtype=torch.float)\n",
    "\n",
    "# Modules start out in training mode and nothing above switches this one to eval. In training\n",
    "# mode BatchNorm normalises with the statistics of the current batch and updates its stored\n",
    "# running statistics, so a mere shape check would alter the pretrained network.\n",
    "# Put the extractor in eval mode first.\n",
    "backbone.eval()\n",
    "\n",
    "# Only the output shape is inspected, so no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    # The extractor returns a dict; '0' is the output name chosen when it was created.\n",
    "    res = backbone(x)['0']\n",
    "print(res.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1000])\n"
     ]
    }
   ],
   "source": [
    "# Run the full classifier on the same dummy batch and show the shape of its output.\n",
    "# Eval mode makes BatchNorm use its stored running statistics instead of updating them\n",
    "# from this one-sample batch; no_grad skips autograd bookkeeping for a pure shape check.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print(model(x).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MobileNetV3(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2dNormActivation(\n",
      "      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "      (2): Hardswish()\n",
      "    )\n",
      "    (1): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)\n",
      "          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (2): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)\n",
      "          (1): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (3): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)\n",
      "          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (4): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False)\n",
      "          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (5): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)\n",
      "          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (6): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)\n",
      "          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): ReLU(inplace=True)\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (7): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False)\n",
      "          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (8): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(200, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False)\n",
      "          (1): BatchNorm2d(200, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (9): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)\n",
      "          (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (10): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)\n",
      "          (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): Conv2dNormActivation(\n",
      "          (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (11): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(480, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False)\n",
      "          (1): BatchNorm2d(480, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(112, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (12): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False)\n",
      "          (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(112, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (13): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False)\n",
      "          (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (14): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)\n",
      "          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (15): InvertedResidual(\n",
      "      (block): Sequential(\n",
      "        (0): Conv2dNormActivation(\n",
      "          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (1): Conv2dNormActivation(\n",
      "          (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)\n",
      "          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "          (2): Hardswish()\n",
      "        )\n",
      "        (2): SqueezeExcitation(\n",
      "          (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "          (fc1): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (activation): ReLU()\n",
      "          (scale_activation): Hardsigmoid()\n",
      "        )\n",
      "        (3): Conv2dNormActivation(\n",
      "          (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (16): Conv2dNormActivation(\n",
      "      (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
      "      (2): Hardswish()\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=960, out_features=1280, bias=True)\n",
      "    (1): Hardswish()\n",
      "    (2): Dropout(p=0.2, inplace=True)\n",
      "    (3): Linear(in_features=1280, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Load MobileNetV3-Large with torchvision's default pretrained weights and dump its module tree.\n",
    "mobile_model = models.mobilenet_v3_large(\n",
    "    weights=models.MobileNet_V3_Large_Weights.DEFAULT\n",
    ")\n",
    "print(mobile_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SSD (Single Shot MultiBox Detector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace=True)\n",
      "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (6): ReLU(inplace=True)\n",
      "    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): ReLU(inplace=True)\n",
      "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace=True)\n",
      "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (13): ReLU(inplace=True)\n",
      "    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (15): ReLU(inplace=True)\n",
      "    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (18): ReLU(inplace=True)\n",
      "    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (20): ReLU(inplace=True)\n",
      "    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (22): ReLU(inplace=True)\n",
      "    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (25): ReLU(inplace=True)\n",
      "    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (27): ReLU(inplace=True)\n",
      "    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (29): ReLU(inplace=True)\n",
      "    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Dropout(p=0.5, inplace=False)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU(inplace=True)\n",
      "    (5): Dropout(p=0.5, inplace=False)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Load VGG-16 (the variant without BatchNorm layers) with torchvision's default pretrained\n",
    "# weights and dump its module tree. Calling models.vgg16() with no arguments instead builds\n",
    "# the same architecture without loading pretrained weights.\n",
    "# NOTE: this rebinds `model`, which held the ResNet-18 until now; `backbone` still wraps\n",
    "# that ResNet.\n",
    "model = models.vgg16(\n",
    "    weights=models.VGG16_Weights.DEFAULT\n",
    ")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU(inplace=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (5): ReLU(inplace=True)\n",
      "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (9): ReLU(inplace=True)\n",
      "    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (12): ReLU(inplace=True)\n",
      "    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (16): ReLU(inplace=True)\n",
      "    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (19): ReLU(inplace=True)\n",
      "    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (22): ReLU(inplace=True)\n",
      "    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (26): ReLU(inplace=True)\n",
      "    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (29): ReLU(inplace=True)\n",
      "    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (32): ReLU(inplace=True)\n",
      "    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (36): ReLU(inplace=True)\n",
      "    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (39): ReLU(inplace=True)\n",
      "    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (42): ReLU(inplace=True)\n",
      "    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Dropout(p=0.5, inplace=False)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU(inplace=True)\n",
      "    (5): Dropout(p=0.5, inplace=False)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model = models.vgg16_bn(weights=models.VGG16_BN_Weights.DEFAULT)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Listing and retrieving available models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the model listing twice: unfiltered, then restricted to the torchvision.models module\n",
    "# (`models` is torchvision.models, imported in the first cell).\n",
    "all_models = models.list_models()\n",
    "clams = models.list_models(module=models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build models from their registered names: one with weights=None, one (quantized) with the default weights\n",
    "m1 = models.get_model(name=\"mobilenet_v3_large\", weights=None)\n",
    "m2 = models.get_model(name=\"quantized_mobilenet_v3_large\", weights=\"DEFAULT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve one weights entry from its dotted string name and check it equals the enum member\n",
    "weights = models.get_weight(name=\"MobileNet_V3_Large_Weights.DEFAULT\")\n",
    "assert weights == models.MobileNet_V3_Large_Weights.DEFAULT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the weights enum class for a model, looked up by its registered name\n",
    "weights_enum = models.get_model_weights(name=\"quantized_mobilenet_v3_large\")\n",
    "assert weights_enum == models.quantization.MobileNet_V3_Large_QuantizedWeights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same lookup, this time passing the builder function instead of its name;\n",
    "# the result is compared with `weights_enum` from the previous cell.\n",
    "weights_enum2 = models.get_model_weights(name=models.quantization.mobilenet_v3_large)\n",
    "assert weights_enum2 == weights_enum"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using PyTorch Hub\n",
    "无需安装 torchvision，也可以通过 `torch.hub` 直接加载模型及其预训练权重。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/xiaoliang/.cache/torch/hub/pytorch_vision_main\n",
      "Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /home/xiaoliang/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n",
      "100%|██████████| 97.8M/97.8M [06:31<00:00, 262kB/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the resnet50 entry point from the pytorch/vision hub repo with its default weights\n",
    "model = torch.hub.load(repo_or_dir=\"pytorch/vision\", model=\"resnet50\", weights=\"DEFAULT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative, kept disabled: fetch the weights entry through the hub first, then pass it to the builder\n",
    "# weights = torch.hub.load(\"pytorch/vision\", \"get_weight\", weights=\"ResNet50_Weights.DEFAULT\")\n",
    "# model = torch.hub.load(\"pytorch/vision\", \"resnet50\", weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ResNet50_Weights.IMAGENET1K_V1, ResNet50_Weights.IMAGENET1K_V2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/xiaoliang/.cache/torch/hub/pytorch_vision_main\n"
     ]
    }
   ],
   "source": [
    "# Ask the hub for the weights enum of resnet50 and print its members\n",
    "weight_enum = torch.hub.load(\"pytorch/vision\", \"get_model_weights\", name=\"resnet50\")\n",
    "print(list(weight_enum))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "military uniform: 10.0%\n"
     ]
    }
   ],
   "source": [
    "from torchvision.io import read_image\n",
    "from torchvision.models import resnet50, ResNet50_Weights\n",
    "\n",
    "# NOTE(review): absolute path specific to the author's machine - adjust before re-running elsewhere.\n",
    "img = read_image(\"/home/xiaoliang/workspace/python/advanceDL/data/assets/encode_jpeg/grace_hopper_517x606.jpg\")\n",
    "\n",
    "# Step 1: Initialize model with the best available weights\n",
    "weights = ResNet50_Weights.DEFAULT\n",
    "model = resnet50(weights=weights)\n",
    "model.eval()\n",
    "\n",
    "# Step 2: Initialize the inference transforms\n",
    "preprocess = weights.transforms()\n",
    "\n",
    "# Step 3: Apply inference preprocessing transforms\n",
    "batch = preprocess(img).unsqueeze(0)\n",
    "\n",
    "# Step 4: Use the model and print the predicted category\n",
    "# Inference only: disable autograd so no graph is recorded for the forward pass.\n",
    "with torch.no_grad():\n",
    "    prediction = model(batch).squeeze(0).softmax(0)\n",
    "class_id = prediction.argmax().item()\n",
    "score = prediction[class_id].item()\n",
    "category_name = weights.meta[\"categories\"][class_id]\n",
    "print(f\"{category_name}: {100 * score:.1f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "d2l",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
